{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function to differentiate is:\n",
    "\n",
    "$$ f(x_0, x_1) = \\sin(x_0) (x_0 + x_1) $$\n",
    "\n",
    "or, broken down into elementary operations:\n",
    "\n",
    "$$ \\begin{aligned}\n",
    "z_0 &= x_0 \\\\\n",
    "z_1 &= x_1 \\\\\n",
    "z_2 &= \\sin(z_0) \\\\\n",
    "z_3 &= z_0 + z_1 \\\\\n",
    "z_4 &= z_2 z_3\n",
    "\\end{aligned} $$\n",
    "\n",
    "Its symbolic gradient is:\n",
    "\n",
    "$$ \\nabla f(x_0, x_1) = \\begin{bmatrix}\n",
    "\\cos(x_0) (x_0 + x_1) + \\sin(x_0) \\\\\n",
    "\\sin(x_0)\n",
    "\\end{bmatrix} $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x_0, x_1):\n",
    "    \"\"\"Closed-form reference: sin(x_0) * (x_0 + x_1).\"\"\"\n",
    "    sine_factor = np.sin(x_0)\n",
    "    sum_factor = x_0 + x_1\n",
    "    return sine_factor * sum_factor\n",
    "\n",
    "def f_grad(x_0, x_1):\n",
    "    \"\"\"Hand-derived gradient of f, returned as np.array([df/dx_0, df/dx_1]).\"\"\"\n",
    "    d_dx0 = np.cos(x_0) * (x_0 + x_1) + np.sin(x_0)\n",
    "    d_dx1 = np.sin(x_0)\n",
    "    return np.array([d_dx0, d_dx1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (operation, argument indices) pair per node; a node's own index is its\n",
    "# position in this list. compute() and vjp() skip \"inp\" entries and take the\n",
    "# leading values straight from the inputs they are given.\n",
    "compute_graph = [\n",
    "    (\"inp\", (0,)),    # z_0 = x_0\n",
    "    (\"inp\", (1,)),    # z_1 = x_1\n",
    "    (\"sin\", (0,)),    # z_2 = sin(z_0)\n",
    "    (\"add\", (0, 1)),  # z_3 = z_0 + z_1\n",
    "    (\"mul\", (2, 3)),  # z_4 = z_2 * z_3\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward-only implementation of each primitive, keyed by operation name.\n",
    "fn_library = dict(\n",
    "    inp=lambda value: value,\n",
    "    sin=lambda angle: np.sin(angle),\n",
    "    add=lambda lhs, rhs: lhs + rhs,\n",
    "    mul=lambda lhs, rhs: lhs * rhs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute(graph, inputs):\n",
    "    \"\"\"Evaluate the graph on `inputs` and return the value of its last node.\n",
    "\n",
    "    Node values start out as a copy of `inputs`; every non-\"inp\" entry of the\n",
    "    graph appends one value computed with the matching fn_library function.\n",
    "    \"\"\"\n",
    "    node_values = list(inputs)\n",
    "    for op_name, arg_ids in graph:\n",
    "        if op_name != \"inp\":\n",
    "            operands = [node_values[k] for k in arg_ids]\n",
    "            node_values.append(fn_library[op_name](*operands))\n",
    "\n",
    "    return node_values[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation point (x_0, x_1) reused by the cells that follow.\n",
    "SAMPLE_INPUT = (0.6, 1.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.1292849467900707"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference value from the closed-form f.\n",
    "f(*SAMPLE_INPUT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.1292849467900707"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same point evaluated by walking the compute graph.\n",
    "compute(compute_graph, SAMPLE_INPUT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inp_backprop_rule(x):\n",
    "    \"\"\"Backprop rule for an input node: identity forward, identity pullback.\n",
    "\n",
    "    Returns (x, inp_pullback); inp_pullback maps the output cotangent to a\n",
    "    1-tuple holding that same cotangent.\n",
    "    \"\"\"\n",
    "    def inp_pullback(z_cotangent):\n",
    "        return (z_cotangent,)\n",
    "\n",
    "    return x, inp_pullback\n",
    "\n",
    "def sin_backprop_rule(x):\n",
    "    \"\"\"Backprop rule for sine.\n",
    "\n",
    "    Returns (sin(x), sin_pullback); sin_pullback scales the output cotangent\n",
    "    by cos(x) and returns it as a 1-tuple.\n",
    "    \"\"\"\n",
    "    def sin_pullback(z_cotangent):\n",
    "        return (np.cos(x) * z_cotangent,)\n",
    "\n",
    "    return np.sin(x), sin_pullback\n",
    "\n",
    "def add_backprop_rule(x, y):\n",
    "    \"\"\"Backprop rule for addition.\n",
    "\n",
    "    Returns (x + y, add_pullback); add_pullback hands the output cotangent\n",
    "    unchanged to both operands.\n",
    "    \"\"\"\n",
    "    def add_pullback(z_cotangent):\n",
    "        return (z_cotangent, z_cotangent)\n",
    "\n",
    "    return x + y, add_pullback\n",
    "\n",
    "def mul_backprop_rule(x, y):\n",
    "    \"\"\"Backprop rule for multiplication.\n",
    "\n",
    "    Returns (x * y, mul_pullback); mul_pullback scales the output cotangent\n",
    "    by the other operand for each argument.\n",
    "    \"\"\"\n",
    "    def mul_pullback(z_cotangent):\n",
    "        return (y * z_cotangent, x * z_cotangent)\n",
    "\n",
    "    return x * y, mul_pullback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backprop rule for each primitive, keyed by operation name.\n",
    "backprop_library = dict(\n",
    "    inp=inp_backprop_rule,\n",
    "    sin=sin_backprop_rule,\n",
    "    add=add_backprop_rule,\n",
    "    mul=mul_backprop_rule,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vjp(graph, inputs):\n",
    "    \"\"\"Run the graph forward and build a pullback for reverse-mode differentiation.\n",
    "\n",
    "    Returns (output, pullback). Calling pullback(output_cotangent) propagates the\n",
    "    cotangent from the last node back through every recorded operation and returns\n",
    "    the cotangents of the first len(inputs) nodes as a NumPy array.\n",
    "    \"\"\"\n",
    "    node_values = list(inputs)\n",
    "    tape = []  # (operation pullback, argument indices), in forward order\n",
    "\n",
    "    for op_name, arg_ids in graph:\n",
    "        if op_name != \"inp\":\n",
    "            operands = [node_values[k] for k in arg_ids]\n",
    "            primal, op_pullback = backprop_library[op_name](*operands)\n",
    "            node_values.append(primal)\n",
    "            tape.append((op_pullback, arg_ids))\n",
    "\n",
    "    def pullback(output_cotangent):\n",
    "        adjoints = np.zeros(len(node_values))\n",
    "        adjoints[-1] = output_cotangent\n",
    "\n",
    "        # Tape entries correspond to the trailing nodes, so walk both backwards together.\n",
    "        node_id = len(node_values) - 1\n",
    "        for op_pullback, arg_ids in tape[::-1]:\n",
    "            contributions = op_pullback(adjoints[node_id])\n",
    "            for k, contribution in zip(arg_ids, contributions):\n",
    "                adjoints[k] += contribution\n",
    "            node_id -= 1\n",
    "\n",
    "        return adjoints[:len(inputs)]\n",
    "\n",
    "    return node_values[-1], pullback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass through the graph; back_fn is the pullback returned by vjp.\n",
    "out, back_fn = vjp(compute_graph, SAMPLE_INPUT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.1292849467900707"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Forward value returned by vjp.\n",
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2.2153137 , 0.56464247])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pull an output cotangent of 1.0 back to the cotangents of the inputs.\n",
    "back_fn(1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2.2153137 , 0.56464247])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Closed-form gradient at the same point, for comparison.\n",
    "f_grad(*SAMPLE_INPUT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def value_and_grad(graph, inputs):\n",
    "    \"\"\"Return (output, gradient) of the graph at `inputs`.\n",
    "\n",
    "    The gradient is the vjp pullback evaluated at an output cotangent of 1.0.\n",
    "    \"\"\"\n",
    "    primal_out, pull = vjp(graph, inputs)\n",
    "    return primal_out, pull(1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.1292849467900707, array([2.2153137 , 0.56464247]))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Graph-based value and gradient in one call.\n",
    "value_and_grad(compute_graph, SAMPLE_INPUT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.1292849467900707, array([2.2153137 , 0.56464247]))"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Closed-form value and gradient, for comparison.\n",
    "f(*SAMPLE_INPUT), f_grad(*SAMPLE_INPUT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
